{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "b2748b09-6958-4cb0-abe9-7b76fdacedac",
   "metadata": {},
   "source": [
    "## Structured generation with LLM(1)：Kor"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "bb883b1b-cabc-4b27-a405-8d99232165c1",
   "metadata": {},
   "source": [
    "第一步，安装必要的库 `pip3 install kor langchain pydantic`\n",
    "\n",
    "然后，导入它们"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "facf43e8-a12d-407c-b443-2aebb97501d8",
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "from langchain_community.chat_models import ChatOpenAI\n",
    "from kor import create_extraction_chain, Object, Text, Number, from_pydantic\n",
    "import pydantic\n",
    "from pydantic import BaseModel, Field\n",
    "from typing import List, Optional\n",
    "import enum\n",
    "from pprint import pprint"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "ca8a9322-1692-4139-9b7e-f19d84fe89c9",
   "metadata": {},
   "source": [
    "配置openai api的base_url和key，国内许多厂商都适配openai api。\n",
    "\n",
    "安利一个国内的硅基流动，他们家**免费**提供了一些9B以下的优秀开源LLM api，而且也适配openai api；免费的模型列表看这里：https://docs.siliconflow.cn/docs/model-names"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "e66ee194-8267-47d4-b01b-9f6e75189fd2",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/Users/duanyu/.conda/envs/Study/lib/python3.10/site-packages/langchain_core/_api/deprecation.py:139: LangChainDeprecationWarning: The class `ChatOpenAI` was deprecated in LangChain 0.0.10 and will be removed in 0.3.0. An updated version of the class exists in the langchain-openai package and should be used instead. To use it run `pip install -U langchain-openai` and import as `from langchain_openai import ChatOpenAI`.\n",
      "  warn_deprecated(\n"
     ]
    }
   ],
   "source": [
    "os.environ['OPENAI_BASE_URL'] = \"https://api.siliconflow.cn/v1\" # 本次实验使用了硅基流动的免费api\n",
    "os.environ['OPENAI_API_KEY'] = \"xxx\" # 替换成自己申请的api key\n",
    "\n",
    "llm = ChatOpenAI(\n",
    "    model_name='THUDM/glm-4-9b-chat',\n",
    "    temperature=0, # 设置成0最稳定；structured generation中稳定最重要\n",
    "    max_tokens=2000,\n",
    "    model_kwargs = {\n",
    "        'frequency_penalty':0,\n",
    "        'presence_penalty':0,\n",
    "        'top_p':1.0,\n",
    "    }\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "16a575ff-c9e2-4529-b5ca-a2a98ee40065",
   "metadata": {},
   "source": [
    "## Example 1: 中文翻译器\n",
    "\n",
    "效果：输入任意文本，返回{\"translate_result\": {\"chinese\": 翻译结果}}\n",
    "\n",
    "在结构化输出中，一般只需两步即可：\n",
    "\n",
    "1. 设置schema（即想要llm输出的结构，同时包含注释、例子）；\n",
    "2. 用结构化输出工具（例如本文提到的Kor）得到schema结果。\n",
    "\n",
    "Kor支持两种设置schema的模式，*Kor schema*和*Pydantic Model*，在这个例子中，我们使用Kor schema。\n",
    "\n",
    "**注意**：此处不对Kor做过多介绍，细节请读者参阅文档：https://eyurtsev.github.io/kor/"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "ebc288a8-4aca-47db-8d7b-76372fc67a32",
   "metadata": {},
   "outputs": [],
   "source": [
    "# kor schema，我们想要的输出格式\n",
    "schema = Object(\n",
    "    id=\"translate_result\",\n",
    "    description=(\n",
    "        \"任意文本的翻译结果。\"\n",
    "    ),\n",
    "    attributes=[\n",
    "        Text(\n",
    "            id=\"chinese\",\n",
    "            description=\"中文翻译结果\",\n",
    "            examples=[], # Kor支持few-shot examples，但本例子比较简单，故不需要\n",
    "            many=False, \n",
    "        ),\n",
    "    ],\n",
    "    many=False,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "dd6d2408-4cb3-4bcf-b0cc-3b714db85f78",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/Users/duanyu/.conda/envs/Study/lib/python3.10/site-packages/langchain_core/_api/deprecation.py:139: LangChainDeprecationWarning: The class `LLMChain` was deprecated in LangChain 0.1.17 and will be removed in 1.0. Use RunnableSequence, e.g., `prompt | llm` instead.\n",
      "  warn_deprecated(\n",
      "/Users/duanyu/.conda/envs/Study/lib/python3.10/site-packages/langchain_core/_api/deprecation.py:139: LangChainDeprecationWarning: The method `Chain.run` was deprecated in langchain 0.1.0 and will be removed in 0.3.0. Use invoke instead.\n",
      "  warn_deprecated(\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "{'translate_result': {'chinese': '我们训练了一个基于GPT-4的模型，称为CriticGPT，用于捕捉ChatGPT代码输出的错误。我们发现，当人们从CriticGPT那里获得帮助来审查ChatGPT代码时，他们比没有帮助的人高出60%的效率。我们正在开始将类似CriticGPT的模型集成到我们的RLHF标记流程中，为我们的训练师提供明确的AI辅助。这是朝着能够评估来自高级AI系统的输出迈出的一步，这些输出在没有更好的工具的情况下很难被人类评估。'}}\n"
     ]
    }
   ],
   "source": [
    "# 运行结果\n",
    "chain = create_extraction_chain(llm, schema, encoder_or_encoder_class='json')\n",
    "text = \"We've trained a model, based on GPT-4, called CriticGPT to catch errors in ChatGPT's code output. We found that when people get help from CriticGPT to review ChatGPT code they outperform those without help 60% of the time. We are beginning the work to integrate CriticGPT-like models into our RLHF labeling pipeline, providing our trainers with explicit AI assistance. This is a step towards being able to evaluate outputs from advanced AI systems that can be difficult for people to rate without better tools.\"\n",
    "print(chain.run(text)['data'])"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "7ec15115-6434-41ac-9c96-3b6b8ae4efd6",
   "metadata": {},
   "source": [
    "**示例1成功运行：）**\n",
    "\n",
    "我们打印kor的prompt来看看。"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "id": "c9f79264-bef6-4f62-8c11-dba1b1748eb8",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Your goal is to extract structured information from the user's input that matches the form described below. When extracting information please make sure it matches the type information exactly. Do not add any attributes that do not appear in the schema shown below.\n",
      "\n",
      "```TypeScript\n",
      "\n",
      "translate_result: { // 任意文本的翻译结果。\n",
      " chinese: string // 中文翻译结果\n",
      "}\n",
      "```\n",
      "\n",
      "\n",
      "Please output the extracted information in JSON format. Do not output anything except for the extracted information. Do not add any clarifying information. Do not add any fields that are not in the schema. If the text contains attributes that do not appear in the schema, please ignore them. All output must be in JSON format and follow the schema specified above. Wrap the JSON in <json> tags.\n",
      "\n",
      "\n",
      "\n",
      "Input: [user input]\n",
      "Output:\n"
     ]
    }
   ],
   "source": [
    "print(chain.prompt.format_prompt(text=\"[user input]\").to_string())"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "6ef20dd0-b2fd-47b7-9aad-5e1161781661",
   "metadata": {},
   "source": [
    "接着，我们进入第二个示例。\n",
    "\n",
    "## Example 2：评价解析\n",
    "\n",
    "预期效果：输入一段用户评价，得到评价属性（口味、价格等）、评价极性（正向、负向、中立）、评价词（好吃、贵等）、参考片段。\n",
    "\n",
    "结构化输出，第一步是定义schema，我们可以设置成这样的schema\n",
    "\n",
    "```\n",
    "[\n",
    "    {\n",
    "        'aspect': 评价属性,\n",
    "        'sentiment': 评价极性,\n",
    "        'sentiment_word': 评价词,\n",
    "        'reference': 参考片段,\n",
    "    }\n",
    "]\n",
    "```\n",
    "\n",
    "在这个例子中，我们使用*Pydantic Model*来定义schema，*Pydantic Model*也能够支持few-shot examples，其额外好处是可以Validate"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "id": "bcb04f98-6285-4350-8694-e3078a5ce338",
   "metadata": {},
   "outputs": [],
   "source": [
    "# 评价解析的pydantic model\n",
    "class Sentiment(enum.Enum):\n",
    "    positive = \"positive\"\n",
    "    negative = \"negative\"\n",
    "    neural = \"neural\"\n",
    "\n",
    "class Dianpin(BaseModel):\n",
    "    aspect: str = Field(\n",
    "        description=\"评价属性\"\n",
    "    )\n",
    "    sentiment_word: str = Field(\n",
    "        description='对评价属性的评价词，从原文中抽取',\n",
    "    )\n",
    "    sentiment: Optional[Sentiment] = Field(\n",
    "        description='对评价属性的情感，positive\\negative\\neural中的一个',\n",
    "    )\n",
    "    reference: str = Field(\n",
    "        description='评价的原文'\n",
    "    ) "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "id": "b1d9c8c7-c1d0-4dcd-a62d-40ff23a2ef4e",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "{'data': {},\n",
      " 'errors': [ParseError('The LLM has returned structured data which does not match the expected schema. Providing additional examples may help improve the parse.')],\n",
      " 'raw': '\\n'\n",
      "        '<json>\\n'\n",
      "        '[\\n'\n",
      "        '  {\\n'\n",
      "        '    \"aspect\": \"环境\",\\n'\n",
      "        '    \"sentiment_word\": \"可以\",\\n'\n",
      "        '    \"sentiment\": \"positive\"\\n'\n",
      "        '  },\\n'\n",
      "        '  {\\n'\n",
      "        '    \"aspect\": \"味道\",\\n'\n",
      "        '    \"sentiment_word\": \"还不错\",\\n'\n",
      "        '    \"sentiment\": \"positive\"\\n'\n",
      "        '  },\\n'\n",
      "        '  {\\n'\n",
      "        '    \"aspect\": \"价格\",\\n'\n",
      "        '    \"sentiment_word\": \"小贵\",\\n'\n",
      "        '    \"sentiment\": \"negative\"\\n'\n",
      "        '  }\\n'\n",
      "        ']\\n'\n",
      "        '</json>',\n",
      " 'validated_data': {}}\n"
     ]
    }
   ],
   "source": [
    "# 运行kor\n",
    "schema, validator = from_pydantic(\n",
    "    Dianpin, \n",
    "    description='对评价的解析结果', \n",
    "    examples=[],  \n",
    "    many=True #支持list of aspect\n",
    ")\n",
    "chain = create_extraction_chain(\n",
    "    llm, schema, encoder_or_encoder_class=\"json\", validator=validator\n",
    ")\n",
    "\n",
    "pprint(chain.run(\"整体来说，环境可以，味道的话也还不错，但价格有一点小贵。\"))"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "4159a717-f904-43e8-9e19-4da3744f42cc",
   "metadata": {},
   "source": [
    "**注意**，此时`data`字段数据为空，**因为LLM的返回不符合预期的schema**，kor建议加入examples\n",
    "\n",
    "于是我们加入一个简单的example"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "id": "b0e4d698-bc58-4629-a3a9-0b279206108d",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "{'data': {'dianpin': [{'aspect': '环境',\n",
      "                       'reference': '整体来说，环境可以',\n",
      "                       'sentiment': 'positive',\n",
      "                       'sentiment_word': '可以'},\n",
      "                      {'aspect': '味道',\n",
      "                       'reference': '味道的话也还不错',\n",
      "                       'sentiment': 'positive',\n",
      "                       'sentiment_word': '还不错'},\n",
      "                      {'aspect': '价格',\n",
      "                       'reference': '但价格有一点小贵',\n",
      "                       'sentiment': 'negative',\n",
      "                       'sentiment_word': '小贵'}]},\n",
      " 'errors': [],\n",
      " 'raw': '\\n'\n",
      "        '<json>\\n'\n",
      "        '{\\n'\n",
      "        '  \"dianpin\": [\\n'\n",
      "        '    {\\n'\n",
      "        '      \"aspect\": \"环境\",\\n'\n",
      "        '      \"sentiment_word\": \"可以\",\\n'\n",
      "        '      \"sentiment\": \"positive\",\\n'\n",
      "        '      \"reference\": \"整体来说，环境可以\"\\n'\n",
      "        '    },\\n'\n",
      "        '    {\\n'\n",
      "        '      \"aspect\": \"味道\",\\n'\n",
      "        '      \"sentiment_word\": \"还不错\",\\n'\n",
      "        '      \"sentiment\": \"positive\",\\n'\n",
      "        '      \"reference\": \"味道的话也还不错\"\\n'\n",
      "        '    },\\n'\n",
      "        '    {\\n'\n",
      "        '      \"aspect\": \"价格\",\\n'\n",
      "        '      \"sentiment_word\": \"小贵\",\\n'\n",
      "        '      \"sentiment\": \"negative\",\\n'\n",
      "        '      \"reference\": \"但价格有一点小贵\"\\n'\n",
      "        '    }\\n'\n",
      "        '  ]\\n'\n",
      "        '}\\n'\n",
      "        '</json>',\n",
      " 'validated_data': [Dianpin(aspect='环境', sentiment_word='可以', sentiment=<Sentiment.positive: 'positive'>, reference='整体来说，环境可以'),\n",
      "                    Dianpin(aspect='味道', sentiment_word='还不错', sentiment=<Sentiment.positive: 'positive'>, reference='味道的话也还不错'),\n",
      "                    Dianpin(aspect='价格', sentiment_word='小贵', sentiment=<Sentiment.negative: 'negative'>, reference='但价格有一点小贵')]}\n"
     ]
    }
   ],
   "source": [
    "# 运行kor\n",
    "schema, validator = from_pydantic(\n",
    "    Dianpin, \n",
    "    description='对评价的解析结果', \n",
    "    examples=[\n",
    "        ('味道真不错，下次还来！', [{\"aspect\":\"味道\", \"sentiment_word\": \"真不错\", \"sentiment\": \"positive\", \"reference\": \"味道真不错\"}])\n",
    "    ],\n",
    "    many=True #支持list of aspect\n",
    ")\n",
    "chain = create_extraction_chain(\n",
    "    llm, schema, encoder_or_encoder_class=\"json\", validator=validator\n",
    ")\n",
    "\n",
    "pprint(chain.run(\"整体来说，环境可以，味道的话也还不错，但价格有一点小贵。\"))"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "21697aeb-7d69-4a3b-bd2d-1ea98e8c75c1",
   "metadata": {},
   "source": [
    "**示例2也成功运行啦！**\n",
    "\n",
    "我们也打印kor的prompt，看看长什么样，以及few-shot examples是如何使用的。"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "id": "bccee251-d0e0-4086-b48f-ee79558969fa",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Your goal is to extract structured information from the user's input that matches the form described below. When extracting information please make sure it matches the type information exactly. Do not add any attributes that do not appear in the schema shown below.\n",
      "\n",
      "```TypeScript\n",
      "\n",
      "dianpin: Array<{ // 对评价的解析结果\n",
      " aspect: string // 评价属性\n",
      " sentiment_word: string // 对评价属性的评价词，从原文中抽取\n",
      " sentiment: \"positive\" | \"negative\" | \"neural\" // 对评价属性的情感，positive\n",
      "egative\n",
      "eural中的一个\n",
      " reference: string // 评价的原文\n",
      "}>\n",
      "```\n",
      "\n",
      "\n",
      "Please output the extracted information in JSON format. Do not output anything except for the extracted information. Do not add any clarifying information. Do not add any fields that are not in the schema. If the text contains attributes that do not appear in the schema, please ignore them. All output must be in JSON format and follow the schema specified above. Wrap the JSON in <json> tags.\n",
      "\n",
      "\n",
      "\n",
      "Input: 味道真不错，下次还来！\n",
      "Output: <json>{\"dianpin\": [{\"aspect\": \"味道\", \"sentiment_word\": \"真不错\", \"sentiment\": \"positive\", \"reference\": \"味道真不错\"}]}</json>\n",
      "Input: [user input]\n",
      "Output:\n"
     ]
    }
   ],
   "source": [
    "print(chain.prompt.format_prompt(text=\"[user input]\").to_string())"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "36a6af73-53df-4ff6-8b6f-c4d6286342cf",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "88fc9e50-5356-43a6-a6c7-89a5fb534250",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "fdbe2335-f51c-49c1-b4ac-05034e6278e4",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "study",
   "language": "python",
   "name": "study"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.13"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
